{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('../')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import DeepPurpose.DTI as models\n",
    "from DeepPurpose.utils import *\n",
    "from DeepPurpose.dataset import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Beginning Processing...\n",
      "Beginning to extract zip file...\n",
      "Done!\n",
      "in total: 118254 drug-target pairs\n",
      "encoding drug...\n",
      "unique drugs: 2068\n",
      "drug encoding finished...\n",
      "encoding protein...\n",
      "unique target sequence: 229\n",
      "protein encoding finished...\n",
      "splitting dataset...\n",
      "Done.\n",
      "cost about 39 seconds\n"
     ]
    }
   ],
   "source": [
    "from time import time\n",
    "\n",
    "t1 = time()\n",
    "X_drug, X_target, y = load_process_KIBA('./data/', binary=False)\n",
    "\n",
    "drug_encoding = 'CNN'\n",
    "target_encoding = 'Transformer'\n",
    "train, val, test = data_process(X_drug, X_target, y, \n",
    "                                drug_encoding, target_encoding, \n",
    "                                split_method='random',frac=[0.7,0.1,0.2])\n",
    "\n",
    "# use the parameters setting provided in the paper: https://arxiv.org/abs/1801.10193\n",
    "config = generate_config(drug_encoding = drug_encoding, \n",
    "                         target_encoding = target_encoding, \n",
    "                         cls_hidden_dims = [1024,1024,512], \n",
    "                         train_epoch = 100, \n",
    "                         test_every_X_epoch = 10, \n",
    "                         LR = 0.001, \n",
    "                         batch_size = 128,\n",
    "                         hidden_dim_drug = 128,\n",
    "                         mpnn_hidden_size = 128,\n",
    "                         mpnn_depth = 3, \n",
    "                         cnn_target_filters = [32,64,96],\n",
    "                         cnn_target_kernels = [4,8,12]\n",
    "                        )\n",
    "model = models.model_initialize(**config)\n",
    "t2 = time()\n",
    "print(\"cost about \" + str(int(t2-t1)) + \" seconds\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Let's use CPU/s!\n",
      "--- Data Preparation ---\n",
      "--- Go for Training ---\n",
      "Training at Epoch 1 iteration 0 with loss 139.439. Total time 0.00222 hours\n",
      "Training at Epoch 1 iteration 100 with loss 0.85528. Total time 0.19111 hours\n"
     ]
    }
   ],
   "source": [
    "model.train(train, val, test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_model('./model_CNN_Transformer_Kiba')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
